/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.torch

import scala.math._
import com.intel.analytics.bigdl.dllib.nn.{MSECriterion, Sequential, TimeDistributedCriterion}
import com.intel.analytics.bigdl.dllib.optim.{L2Regularizer, SGD}
import com.intel.analytics.bigdl.dllib.tensor.{Storage, Tensor}
import com.intel.analytics.bigdl.dllib.utils.RandomGenerator._
import com.intel.analytics.bigdl.dllib.utils.{T, TestUtils}
import com.intel.analytics.bigdl.dllib.keras.layers.InternalConvLSTM2D
import com.intel.analytics.bigdl.dllib.keras.layers.internal.InternalRecurrent
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class InternalConvLSTMSpec extends FlatSpec with BeforeAndAfter with Matchers {

  // Smoke test for batched input: a random tensor of shape
  // (batchSize, seqLength, inputSize, kernel, kernel) is pushed through several
  // forward/backward passes (padding = -1, i.e. the "same" padding mode used by the
  // fixed-data test below). Each backward pass must yield a gradient shaped like the input.
  "A InternalConvLSTM2D " should " work in BatchMode" in {
    val hiddenSize = 5
    val inputSize = 3
    val seqLength = 4
    val batchSize = 2
    val kernel = 3
    val rec = new InternalRecurrent[Double]()
    val model = Sequential[Double]()
      .add(rec
        .add(new InternalConvLSTM2D[Double](
          inputSize,
          hiddenSize,
          kernel,
          1,
          padding = -1,
          withPeephole = false)))

    // The input is allocated after the model, as before, so RNG draws keep their order.
    val input = Tensor[Double](batchSize, seqLength, inputSize, kernel, kernel).rand()
    for (_ <- 1 to 3) {
      val output = model.forward(input)
      // The forward output doubles as gradOutput; only shapes/no-throw are checked here.
      val gradInput = model.backward(input, output).toTensor[Double]
      gradInput.size() should be (input.size())
    }
  }

  // Smoke test for padding = 0 ("valid" padding): one forward and one backward pass over
  // a random 16x16 single-sample sequence; the test passes as long as neither call throws.
  "A InternalConvLSTM2D " should "work with valid padding" in {
    val (batch, steps, inPlanes, outPlanes) = (1, 4, 3, 5)
    val frameSide = 16
    val kernelSide = 4
    val stride = 1

    // Keep the input allocated before the model: rand() and (presumably) the layer's weight
    // initialisation draw from the same generator, so reordering would change the values.
    val x = Tensor[Double](batch, steps, inPlanes, frameSide, frameSide).rand()

    val recurrent = new InternalRecurrent[Double]()
    val model = Sequential[Double]().add(
      recurrent.add(new InternalConvLSTM2D[Double](inPlanes, outPlanes, kernelSide, stride,
        padding = 0, withPeephole = false)))

    val y = model.forward(x)
    model.backward(x, y)
  }

  // Fixed-data regression test (padding = -1, "same" padding): a single-sample sequence of
  // four 3x3 frames and a fixed parameter vector are pushed through forward; backward is
  // then run with the forward output as gradOutput, and the resulting gradInput is compared
  // element-wise against expectedGradData with an absolute tolerance of 1e-6.
  // NOTE(review): the reference values were presumably produced by an external ConvLSTM2D
  // implementation — TODO record their provenance next to the data.
  "A InternalConvLSTM2D " should "generate corrent output with same padding" in {
    val hiddenSize = 5
    val inputSize = 3
    val seqLength = 4
    val batchSize = 1

    val inputData = Array(
      0.7858528697343803, 0.39643673229432663, 0.5413854281649032,
      0.19631629573541698, 0.9680900797738962, 0.7049105791704678,
      0.718163522037053, 0.3668819786771701, 0.9440464687074629,

      0.397055650490982, 0.5690454413390856, 0.18889661397871826,
      0.001579648466406347, 0.7021713615013074, 0.8454086571332006,
      0.35854686871163577, 0.2314278384571361, 0.13381056284442328,

      0.6132205837353593, 0.7805923402713709, 0.3193504762780809,
      0.4820660231781917, 0.13430830215086442, 0.08936018148523917,
      0.4341675120170957, 0.3757014153775321, 0.38254346806277906,

      0.21915241667840835, 0.8007170972591638, 0.15182768990111906,
      0.16921631594179842, 0.4634274688062553, 0.8563216662553195,
      0.7757094855568256, 0.8645676388878089, 0.2951255327230293,

      0.8551730734700937, 0.4644185607835579, 0.871393203268247,
      0.4597647143442395, 0.4071171727330902, 0.4854337817083433,
      0.3318299516855079, 0.498443023507131, 0.7156391293365701,

      0.3249206769911145, 0.029879345424368986, 0.9096393017575739,
      0.2144207018779739, 0.0503060117834595, 0.089917511783566,
      0.6167131506276432, 0.8926610118451324, 0.22789546589690257,

      0.8653884115886622, 0.2722629410529933, 0.8464231593999603,
      0.1986656076231056, 0.36207025934462145, 0.20100784318653586,
      0.49658835737109863, 0.8639942640204804, 0.6467518899038553,

      0.4596178558550218, 0.47357079720990714, 0.38797563894307674,
      0.9548503051587474, 0.5749238739873289, 0.3847657155455617,
      0.3841002078008404, 0.6294462025466693, 0.011197612494351583,

      0.5527740169991779, 0.22146120004645264, 0.5994096770862478,
      0.31269860199224186, 0.34076707800851624, 0.059283360194189005,
      0.8357335896530725, 0.16944373713090255, 0.8659253353994816,

      0.09845143876193818, 0.11288714233839348, 0.3621625765077783,
      0.7427050848065528, 0.32484987192893533, 0.4459479804922727,
      0.6566592559514339, 0.9947576704428358, 0.7117063859310553,

      0.7744659218984301, 0.23830967806526449, 0.4276461743258111,
      0.44676684424295, 0.820457223019586, 0.20418080384634774,
      0.5640742319761487, 0.25500989650402617, 0.5606232230408129,

      0.9632661452404038, 0.7091629761784937, 0.2987437440895777,
      0.42183273373923236, 0.13458995231781568, 0.8353811682997592,
      0.4543100021283092, 0.13410281765513044, 0.678439795677715
    )
    val input = Tensor[Double](inputData, Array(batchSize, seqLength, inputSize, 3, 3))

    val rec = new InternalRecurrent[Double]()
    val model = Sequential[Double]()
      .add(rec
        .add(new InternalConvLSTM2D[Double](inputSize, hiddenSize, 3, 1, padding = -1,
          withPeephole = false)))

    // Flat parameter vector copied verbatim into model.getParameters()._1 below, so its
    // element order must match the order of that flattened vector.
    val weightData = Array(
      0.1323664, 0.11453647, 0.08062653, 0.12153825, 0.09627097, 0.09425588, -0.12831208,
      0.1123728, 0.12671691, 0.16061302, -0.110098146, 0.17929439, -0.09391218, 0.03184388,
      0.16272181, -0.17903298, -0.11212616, -0.16194142, -0.07123136, 0.171485, -0.06678299,
      0.14097463, -0.076257214, -0.12772666, 0.10127824, 0.14800657, -0.17392491, 0.182817,
      -0.12528893, 0.18364896, -0.14304572, -0.17017141, 0.09545128, 0.0021875147, -0.12588996,
      -0.17018303, -0.13192223, -0.08928953, 0.083996765, 0.113029376, -0.03533393, 0.13787921,
      0.044130187, -0.05275204, 0.06665044, 0.10583809, 0.14640503, -0.0825495, 0.175358,
      -0.17123815, 0.016984237, -0.15625823, 0.060610376, -0.07900191, -0.1368149, -0.14521007,
      0.1376833, -0.15550166, -0.10907967, 0.015855903, -0.16790473, -0.16079058, -0.17573304,
      -0.093626104, -0.14478739, 0.06643916, 0.080685385, -0.048617058, 0.12523253, -0.06701434,
      0.016185585, 0.108312786, -0.1217048, 0.18999666, 0.12596558, 0.1527906, 0.11697532,
      0.11283574, 0.057593197, 0.12615764, 0.072295435, 0.025243226, -0.11434275, -0.014110628,
      0.16677082, -0.17618565, -0.020268144, 0.017786589, 0.10867132, -0.01608419, 0.15298566,
      -0.17555264, 0.13313031, -0.10074914, -0.034776952, -0.02424163, 0.06472268, 0.014123779,
      -0.046546947, -0.018097628, 0.12484821, 0.12694938, 0.08661913, -0.028283825, -0.17283146,
      -0.11884577, 0.054063573, -0.035790008, -0.048655663, -0.10025538, -0.08609334, 0.13622199,
      -0.018927833, -0.13733545, 0.06963552, -0.14686611, -0.12066068, -0.041992664, 0.1182357,
      -0.027649706, 0.18181281, 0.041711636, 0.04432702, -0.116909824, -0.17462158, -0.021893747,
      -0.15286785, -0.17956264, 0.08781697, 0.18018082, -0.12757431, 0.15001388, 0.06699571,
      -0.013572012, 0.011313009, -0.17503677, 0.16929771, -0.1792189, -0.06653591, -0.018526785,
      0.14020926, -0.14813682, -0.042868182, 0.071334295, -0.0331954, -0.022251781, -0.044458825,
      -0.108357064, -0.09425906, 0.12529904, 0.053906064, -0.131506, 0.035108663, -0.027024692,
      0.08062959, -0.07704205, -0.08997438, 0.049663007, -0.0602714, -0.099638045, 0.11657174,
      -0.14531808, -0.14385863, -0.13701083, 0.10701223, 0.07447342, 0.02098617, 0.013973343,
      0.016073138, 0.1054289, -0.12807003, 0.10496164, -0.08306444, -0.07314216, 0.00746669,
      0.13380568, 0.11950841, 0.012040939, -0.08963716, 0.029381294, 0.0444725, -0.14176999,
      -0.046013553, 0.14456055, -0.020257693, 0.11915037, -0.13432308, -0.04188419, -0.14749014,
      0.10679834, -0.008356104, -0.091679156, -0.052243225, -0.12075397,
      -0.122984625, -0.0019464424,
      -0.009337306, -0.09842576, -0.05086776, -0.13618909, 0.044009373, 0.03276076, 0.14773913,
      -0.08475516, -0.035664886, 0.13492987, -0.0016442818, 0.01504501, 0.10326427, 0.14080256,
      0.13216749, -0.13468772, -0.13552676, -0.07458449, 0.116350845, -0.017945917, -0.079822004,
      0.14235756, 0.108743496, 0.056373693, 0.106244266, 0.13354784, -0.055554073, 0.06573345,
      0.11254954, 0.030720904, -0.12789108, 0.11583857, 0.01126726, -0.076339975, -0.043077517,
      0.14489457, -0.061520822, -0.084330715, -0.13341874, 0.08033921, 0.01965443, -0.028425984,
      -0.10873458, 0.0019523608, -0.14760552, -0.0715915, 0.121757604, 0.06075947, 0.024467388,
      0.08100012, -0.08324363, -0.077451654, -0.14260122, -0.0990978, -0.035064414, -0.122219145,
      0.075672865, -0.097894855, -0.032138903, -0.14660397, -0.09522337, -0.10158224, 0.02410856,
      0.09243034, -0.0098140575, -0.1458175, 0.1018263, 0.072414316, -0.025606995, -0.09973119,
      -0.07166437, 0.024537396, -0.104933284, -0.12598962, 0.073707975, -0.11742054, -0.117542766,
      -0.11412587, -0.11182065, 0.001858102, -0.11021746, 0.14775929, -0.05255193, 0.11530138,
      -0.053862877, -0.12126004, -0.039800987, 0.12030308, 0.06380281, 0.14243272, -0.105793275,
      -0.033665083, 0.08111985, -0.013766574, 0.113472275, -0.084296875, -0.08927287, 0.11604687,
      -4.435518E-4, 0.04164252, -0.10153338, -0.0691632, -0.1405299, 0.045724023, -0.064101756,
      0.040309414, 0.09550836, -0.1334707, -0.048097685, -0.10702615, 0.032742012, 0.11855929,
      0.037364442, -0.01936604, -0.059911903, -0.0876515, 0.0025609094, 0.013815972, 0.0032597338,
      0.048198074, -0.12489824, -0.050333824, 0.043307517, 0.075850576, 0.04691248, 0.14177343,
      0.12245611, 0.041173577, -0.04838414, -0.12059745, 0.034499634, -0.13209838, 0.12898007,
      0.09691984, -0.06896793, -0.09984528, 0.118018836, -0.13144998, -0.12338916, 0.012128665,
      0.10709243, 0.057155527, 0.11058617, -0.042369995, 0.094536856, 0.09929702, -0.04074019,
      -0.09469174, -0.014339956, -0.14446115, -0.06111967, -0.07840543, 0.08384239, 0.04264625,
      0.0666282, 0.12344897, -0.07968074, -0.11735667, -0.024460025, -0.076237164, 0.04076042,
      -0.05550425, -0.03196365, -0.14604644, 0.06142091, 0.057403088, -0.076850355, 0.04749991,
      -0.14173616, -0.12143371, -0.11346632, -0.40388778, 0.41770256, -0.16137226, 0.18514906,
      0.17749049, 0.17038673, 0.0525878, 0.059497084, -0.12689222, 0.0934659, -0.16460975,
      0.0685503, -0.1598997, 0.18707295, -0.18101873, 0.09334456, 0.0068138484, 0.119945474,
      -0.15793812, 0.16077317, -0.12223828, 0.14830972, -0.08173027, -0.14720327, 0.15735061,
      -0.18086493, -0.15324399, 0.041607905, 0.08446815, -0.066285975, -0.021721164, 0.122331254,
      -0.091555275, -0.009002679, 0.05552557, -0.12491941, -0.11549681, -0.11587785, 0.07597942,
      0.12365428, 0.18181923, 0.10287596, -0.110499285, 0.1069625, -0.15017055, 0.03847431,
      -0.10492084, 0.15501796, -0.17868865, -0.15074125, -0.14290886, -0.11815483, 0.07344515,
      0.07833582, -0.018226074, 0.18693122, -0.042623483, -0.08500967, -0.17536701, 0.03428398,
      -0.13042022, 0.106423296, -0.11648117, 0.061523918, 0.05105374, -0.14420861, -0.091228865,
      -0.09801924, 0.14349346, -0.06156704, -0.007966236, -0.07721831, 0.118240215, 0.17742375,
      -0.124138705, 0.15581268, -0.123607285, 0.06057967, -0.14400646, 0.11283668, -0.13120261,
      -0.014970919, -0.0342267, 0.08669027, -0.14633434, -0.0027491183, -0.18417189, -0.10101288,
      -0.08966918, 0.18372758, 0.12753725, -0.12904254, -0.14902571, 0.05452875, 0.04595482,
      0.0018077934, 0.06402101, -0.09695311, -0.031936277, -0.0836957, 0.09798136, -0.10965453,
      -0.15086654, -0.15770726, 0.13075432, 0.04373129, 0.17452799, 0.15790863, 0.1877846,
      0.055126328, 0.06072954, -0.06788629, -0.03286624, -0.08618504, 0.07507046, -0.06212962,
      0.1620441, -0.1030619, -0.09856772, -0.18165682, 0.118544854, 0.08244241, -0.08935274,
      -0.10706232, -0.097339824, 0.12492477, -0.14299503, -0.1733319, 0.046472415, -0.17566924,
      -0.13683586, -0.064733684, -0.13249831, -0.00250903, -0.1423295, 0.19028416, -0.06298632,
      -0.11775568, -0.15973037, -0.16295016, 0.091037214, -0.021105299, 0.114337094, -0.028981507,
      -0.056668505, -0.065267414, -0.045449473, 0.07244931, -0.12886764, -0.03691907, 0.094375856,
      -0.019792939, 0.024737319, 0.09090861, -0.12854996, 0.072243236, -0.030617973, -0.094412014,
      -0.042138487, 0.14073718, -0.101585604, 0.11363049, -0.07387862, -0.07795169, -0.057193495,
      -0.11206202, -0.042421576, -0.07356668, 0.0035800182, 0.14250474, 0.035420645, -0.09439937,
      -0.00454818, 0.04039183, 0.063089624, 0.025078747, -0.08424321, -0.13109599, -0.054199703,
      0.08430544, -0.07249277, 0.05265952, 0.09668127, 0.08280557, 0.102784894, 0.09898424,
      0.027891688, -0.00482123, 0.083571374, 0.13162544, 0.13045219, -0.13977656, 0.09740613,
      -0.09737305, -0.03441763, -0.0721285, -0.04417562, 0.013421286, 0.011711505, 0.09309547,
      0.05629748, 0.051518198, -0.099549964, 0.026949838, 0.11573528, 0.059821997, -0.086777166,
      0.043824308, 0.06498698, 0.025848297, 0.055641714, -0.09549449, -0.13476206, 0.044260006,
      0.059218522, -0.07706229, 0.06561182, -0.00923181, 0.08157991, -0.14256369, -0.021998769,
      5.660323E-4, 0.100776434, -0.08302952, -0.08642054, -0.049740147, 0.12746517, -0.067147635,
      0.050859027, -0.07180711, -0.08228902, -0.077878885, 0.07168357, 0.059669446, -0.08330871,
      -0.045181356, -0.08571539, -0.033329267, -0.12251607, 0.1383708, 0.0305497, -0.09032205,
      -0.14507341, -0.017640868, -0.11075582, 0.12985572, 0.055534806, 0.08301828, 0.003015704,
      -0.13372704, 0.13678057, 0.06473821, 0.04450592, 0.04828933, -0.09687717, 0.046958666,
      -0.13321148, -0.092644796, 0.11233321, 0.10518413, -0.065738074, -0.034296088, 0.1480033,
      0.016602587, 0.014963573, -0.08498246, 0.09283064, 0.022058068, -0.03559866, -0.12306792,
      -0.0899152, -0.13637953, 0.13551217, -0.0945891, 0.12492594, -0.10224763, -0.13880855,
      0.08347999, 0.12732211, 0.05273699, -0.09553963, 0.04250221, 0.026436547, -0.06730121,
      -0.03517642, 0.089389004, 0.032738786, -0.01546961, -0.12659001, 0.09378276, 0.10204366,
      -0.0807472, -0.14217609, -0.10287603, -0.060035314, -0.081469364, 0.14679892, 0.010581807,
      -0.026570352, 0.02302546, -0.028409055, 0.038031608, 0.010894875, -0.14031309, -0.0054970854,
      -0.08796428, -0.06273019, -0.09651355, -0.024267742, -0.062752575, -0.02792022, 0.057130355,
      -0.13787158, 0.14674985, -0.049736664, -0.018737877, 0.06715224, -0.1446277, 0.0037571336,
      0.07341689, 0.14253464, -0.14391004, -0.04377491, -0.06568719, 0.030192256, -0.020413188,
      0.068322495, -0.10879119, 0.057932742, 0.13234846, -0.109634556, 0.082434475, 0.012896727,
      0.07257251, 0.06956354, 0.14261313, 0.10148351, -0.012789844, -0.06141955, -0.14040212,
      -0.12688103, 0.04696185, -0.02195622, -0.13131708, -0.13735592, 0.017267786, -0.11317883,
      0.058699675, 0.068665326, 0.046896778, 0.0686829, 0.0040637283, 0.051919427, 0.012295667,
      0.01212392, 0.13709806, -0.027172498, 0.0227216, -0.019342942, -0.118283324, -0.12041302,
      -0.11384692, 0.068581454, 0.13964377, 0.019542517, 0.07451984, -0.03907243, 0.119501606,
      0.15579715, 0.041966874, -0.269146, 0.29396367, -0.031004338, 0.11562358, 0.10148933,
      0.01713493, 0.11109411, 0.12588945, -0.1567852, 0.007233048, -0.04118252, -0.13922743,
      0.013768866, 0.14614376, 0.03159147, 0.09009805, -0.15100884, -0.045039184, -0.0754665,
      0.16305731, -0.10817033, 0.066871606, -0.0444717, 0.0176596, -0.14596693, 0.17050807,
      0.1107248, 0.045592476, -0.12141662, -0.08798046, 0.08681701, -0.12117194, -0.032277536,
      -0.085312314, -0.104780264, 0.07352057, -0.0038153927, 0.016070124, -0.018401293, -0.08902605,
      0.14795774, 0.076488174, -0.16232127, -0.049365066, 0.18580824, -0.14259012, -0.10463519,
      -0.1639147, 0.005918354, 0.1515928, -0.107107915, -0.030828143, -0.1123979, 0.04319057,
      -0.121740706, -0.08577948, 0.12570882, 0.13560909, 0.18987843, 0.12616682, -0.079119384,
      -0.050766237, 0.020498704, -0.12197697, 0.08224279, 0.062052865, 0.02906278, 0.017115045,
      -0.18389794, 0.12766539, 0.03930737, 0.008523566, 0.18270826, 0.14880183, 0.10221764,
      -0.15133749, 0.082017995, -0.0130814025, -0.027355297, -0.10657773, 0.052748024, -0.115928516,
      -0.046661552, -0.12882695, -0.029029572, -0.06530424, 0.04025339, 0.16088219, -0.10383346,
      0.082959786, 0.0887459, 0.035165705, 0.15808243, 0.16269772, -0.1357966, -0.10179971,
      -0.0051135537, 0.17215031, 0.1328058, -0.01860083, 0.1615801, 0.12609202, 0.16104408,
      -0.021468738, -0.16631302, 0.0871791, 0.014376062, -0.17092176, 0.11434908, 0.122610696,
      0.028687527, 0.050493836, 0.07100734, 0.059703957, -0.082912676, -0.12887894, -0.04537745,
      -0.18151796, 0.056237914, 0.14380231, -0.14045763, -0.09534631, -0.074677795, 0.19009903,
      -0.10265426, -0.1292793, -0.11336175, -0.1784303, -0.07473458, -0.15607946, 0.07124518,
      -0.00909464, -0.045578852, 0.19244981, -0.15921071, 0.07649794, 0.17505349, -0.18314067,
      0.07524149, 0.010540678, 0.09569251, 0.046074994, -0.02197193, 0.086423144, 0.09348795,
      -0.028091308, -0.09014368, -0.10826236, 0.13677588, 0.04464668, 0.034753967, -0.12927626,
      -0.11813127, 0.004072468, 0.026979715, -0.03775997, 0.061768685, -0.08796178, 0.11295643,
      -8.3927833E-4, -0.014426978, -0.11624958, 0.034711517, -0.08129607, 0.03957355, -0.12502019,
      -0.030620545, 0.09466354, -0.11467053, 0.0122420285, -0.055959523, 0.018059604, -0.14006427,
      -0.08326284, 0.041128356, 0.13832542, 0.015870363, -0.08918273, 0.08830697, -0.107775986,
      -0.07253431, -0.040891536, -0.13261265, 0.09009444, 0.14417978, 0.10808684, -0.037405483,
      -0.034763217, 0.050338723, -0.06079984, 0.11137182, -0.014334275, 0.102414526, 0.07237307,
      0.0776832, 0.11163047, -0.04197717, -0.12803015, 0.024760138, 0.0034540703, 0.0042343847,
      0.07084355, -0.06716749, -0.102060705, -0.135623, 0.06517094, 0.1303601, -0.0050743595,
      -0.02650759, -0.02323579, -0.08093262, 0.05246471, -0.13667563, -0.020750018, -0.04469017,
      0.09838917, -0.13269165, -0.122779235, -0.01706858, 0.10670577, 0.054394018, -0.086262256,
      -0.08817667, -0.062382106, 0.08033159, 0.031420924, -0.088938475, 0.12926981, 0.08264999,
      -0.11417435, 0.08372043, 0.0063608363, -0.021203713, 0.014786269, -0.07409301, -0.021681488,
      0.08935846, 0.058311924, -0.082670316, 0.06871885, 0.116076715, -0.08725295, -0.14136109,
      -0.026386688, -0.061878256, 0.14713068, 0.022175662, -0.041593164, 0.14277703, 0.028373852,
      0.040583156, 0.111651644, -0.016684225, -0.10156425, 0.0021935678, -0.14317933, 0.061549712,
      0.13288979, 0.14676924, -0.101627015, -0.07396209, -0.13494837, -0.09710303, 0.054546677,
      0.08524791, -0.08548785, 0.035400804, -0.042056955, -0.08300001, -0.1220187, -0.14316916,
      0.080142096, 0.04776946, 0.0020199257, -0.109317854, 0.08319731, -0.14316626, 0.0773201,
      -0.0937432, 0.067408286, 0.14834292, -0.13614324, 0.13782948, -0.13609166, 0.119039305,
      -0.013118409, -0.07150358, 0.077627346, 0.09797143, -0.06394284, -0.047202796, 0.047071274,
      0.07730204, -0.046518948, 0.05054308, 0.02134939, -0.08509347, 0.13650805, -0.117712826,
      0.102454275, -0.12443718, -0.043289855, -0.02050105, 0.011652269, -0.13120848, 0.096573845,
      -0.08977994, -0.12086315, 0.063193254, -0.004018177, 0.13031282, 0.08007126, -0.027115893,
      -0.09841936, 0.1075952, 0.0168945, 0.023779849, -0.049907062, -0.085106626, 0.0026978212,
      0.032795515, -0.07142712, -0.08590961, -0.09582995, -0.13937126, 0.108712606, -0.14322321,
      0.070150875, -0.13415396, 0.013311513, 0.09438836, -0.13973792, 0.013226105, 0.014542425,
      -0.027073216, -0.023348909, -0.012271121, -0.082431994, -0.0440865, 0.032611705, 0.10746979,
      0.039494224, -0.09831236, 0.05940311, 2.3674325E-4, -0.013012138, -0.08452134, 0.026267232,
      0.055084113, -0.007533836, -0.13804708, -0.07745523, 0.055061925, -0.13331874, 0.055787712,
      0.14850645, 0.09001049, -0.113049194, 0.1034078, -0.053732794, 0.005184802, -0.10854875,
      -0.032727003, 0.069370605, -0.0699945, -0.06817097, 0.10078447, 0.07813147, 0.033056628,
      -0.1704846, -0.15495631, -0.066432096, -0.13988744, 0.019557567, -0.06827179, 0.010410018,
      0.11517244, 0.06752451, -0.07570038, -0.069486514, -0.01585388, 0.13004497, 0.05301245,
      -0.04380925, -0.057413254, 0.18796284, -0.025588196, -0.055301014, -0.18797436, 0.042079464,
      0.049512137, 0.036646772, -0.028730556, 0.011647098, 0.12882207, -0.17270355, -0.0035952448,
      0.1648032, 0.031258453, 0.06238857, 0.021837315, -0.021809775, -0.013917966, -0.13398269,
      -0.09437768, -0.17708766, 0.05525659, -0.07333484, -0.08282379, 0.11259561, -0.021445781,
      -0.054697804, -0.14095771, 0.16938439, -0.08627576, -0.05361658, -0.12766492, 0.17628285,
      0.10054519, 0.11889319, 0.05916385, 0.17565647, 0.17474626, -0.084707394, -0.06529905,
      0.027251877, 0.045083825, -0.068300545, 0.07404977, -0.106370196, -0.18558562, 0.0012421582,
      -0.07639284, 0.09119752, -0.16026673, 0.03536145, 0.0415398, 0.14838776, -0.08839314,
      -0.095788084, -0.07995972, 0.03311477, -0.110495165, -0.12813115, 0.014271884, 0.09901724,
      -0.1511705, 0.008499655, -0.15553316, 0.098703355, 0.1498216, 0.076806694, 0.006844005,
      -0.109560326, 0.090494744, 0.14271256, -0.18168364, 0.18955807, 0.01633289, -0.18395755,
      0.114932604, 0.10402427, 0.038280725, -0.17810583, 0.01789274, -0.048914507, 0.09589668,
      -0.06169789, 0.12770708, -0.18387029, -0.04703918, 0.17488319, -0.12249794, 0.060243525,
      -0.056309838, -0.032533053, -0.1619768, 0.17675658, 0.014305584, 0.14060433, -0.042199265,
      0.05712904, 0.044829167, 0.011096618, -0.14014772, -0.17880245, 0.16239895, 0.1751847,
      0.13345739, -0.087636806, -0.092316225, -0.027283559, -0.07892386, 0.18008122, 0.053961694,
      -0.16841035, -0.11244774, -0.18884337, 0.19002426, -0.12775706, -0.0441988, -0.064007156,
      -0.006827585, -0.034391925, 0.017185388, 0.14070128, 0.18855122, 0.16581473, -0.07020123,
      -0.05308279, -0.10519481, 0.14562331, 0.10377785, 0.048685696, -0.03890746, -0.12079896,
      -0.050986253, -0.062715515, -0.072706096, -0.12474842, -0.03164121, 0.018972829, 0.09503959,
      0.03272181, 0.0718327, 0.07293465, 0.0356435, 0.0070007364, 0.030311415, -0.028329633,
      -0.07726967, 0.02418142, 0.1357774, 0.14425409, 0.042558707, -0.061877035, 0.13747448,
      0.13204801, 0.13962367, 0.050360594, 0.11418883, 0.12632264, -0.06569335, 0.120681286,
      0.07733086, -0.029515827, -0.13601883, -0.017011724, -0.010782417, 0.014715333, -0.052253734,
      -0.032368187, 0.13431908, -0.043660387, 0.096442826, -0.13594243, 0.09075854, 0.0146357855,
      0.061830286, 0.038969733, 0.112418495, -0.04555642, -0.12400285, 0.094084635, 0.0386825,
      0.05185844, 0.07179485, 0.12685227, -0.07149219, -0.09555144, -0.12677261, 0.014616831,
      0.11663383, -0.14626648, 0.12032344, 0.052169777, -0.14531699, -0.14825104, 0.0458667,
      0.14745489, -0.11824784, -0.143719, -0.032160666, 0.013308394, -0.038604915, 0.046203095,
      -0.061285798, -0.011695573, 0.05354165, -0.0021355557, 0.03226571, -0.13687798, 0.07708235,
      0.012822559, 0.13057984, -0.047273517, -0.046722367, 0.13640164, -0.1134389, -0.018769452,
      0.033678796, 0.10306464, -0.14307259, 0.08554146, 0.05952141, 0.0710592, -0.07967377,
      -0.110443026, -0.051775664, 0.08555203, 0.12498455, 0.010950238, 0.07806091, 0.02533094,
      -0.12941347, -0.0676599, 0.047001403, 0.13817227, -0.041809034, -0.039458882, -0.1277546,
      -0.09346504, 0.045584284, -0.013046625, 0.13334489, -0.030883985, 0.06816985, -0.094920315,
      8.6036685E-4, -0.063012615, 0.022756761, -0.14304605, 0.0018345796,
      -0.027446955, -0.062981054,
      0.049244262, 0.08467103, 0.04948917, 0.12852058, -0.054968648, -0.0013518144, 0.08650799,
      0.046505045, 0.12778473, -0.04336263, 0.12158557, -0.048306707, -0.07632669, -0.118921325,
      0.08370227, 0.10476553, 0.1153842, 0.076464675, 0.077894576, 0.11230735, -0.006782354,
      0.043779783, 0.124212556, -0.06815983, 0.07183016, 0.103715405, 0.0060186437, -0.07855222,
      -0.068271, -0.11873516, -0.03313894, -0.12967804, 0.034167856, -0.0710432, -0.13800886,
      -0.06777998, 0.0073924917, 0.07163505, 0.05019128, -0.10930661, -0.061820634, -0.016123764,
      0.05254672, -0.06823265, 0.1219117, 0.016617121, -0.07891416, 0.034150273, 0.122803226,
      0.06793002, 0.09735866, -0.02079113, -0.13155958, 0.13877551, -0.10352277, 0.044731736,
      0.13054536, -0.0034577732, -0.10844458, 0.1308905, 0.1355866, -0.1172533, 0.13036548,
      -0.06841189, -0.14214517, 0.029879646, 0.14223722, -0.09950151, 0.0975786, -0.08808848,
      -0.102723174, 0.04360433, 0.043727856, -0.102559224, -0.068816505, 0.14440708, 0.038142376,
      0.11174967, -0.13707735, 0.11471822, 0.07180096, -0.089118496, -0.0041832873, -0.06590406,
      -0.06827257, 0.115748845, 0.14643653, -0.13591826
    )
    val weights = model.getParameters()._1
    // Fail with an explicit size mismatch (rather than an opaque copy error) if the
    // model's parameter layout ever changes.
    weights.nElement() should be (weightData.length)
    weights.copy(Tensor[Double](weightData, Array(weightData.size, 1)))

    val output = model.forward(input)
    // The forward output is reused as gradOutput for the backward pass.
    val gradInput = model.backward(input, output).toTensor[Double]

    val expectedGradData = Array(
      0.014530784777034018, 0.009306995336246415, -0.008286519510758553,
      -0.005862022951432651, 0.03774343783342986, 0.02559209113219452,
      -0.009158373635152154, -0.006768851798254148, 0.02421276189520975,

      0.019129483877254584, 0.03387074024263973, 0.016151513843310296,
      0.009110519248782582, 0.0519530219650559, 0.02990668621337847,
      0.0038953190691805974, 0.00881452262673705, 0.006641299351351705,

      0.011895276575900459, 0.010718002608531207, 5.692818102669443E-5,
      5.613122237450347E-4, 0.004340609346706831, 0.009249622360370996,
      0.007077281923732542, 0.025274489209446492, 0.030983433996406604,

      0.01550887654620608, 0.015250729260157744, -5.386170921407475E-4,
      -0.013094061514047645, 0.03932976633855383, 0.016212649176254017,
      -0.014008845387900731, -9.695636289087138E-4, 0.013443209145444692,

      0.021871534544684157, 0.04383469258878202, 0.030960134463390295,
      0.025071766698666193, 0.048735184244840295, 0.030630173082074514,
      -0.005170600367028816, 0.009304571633078955, 0.014833103616387424,

      0.010256073366121572, 0.009494353388598848, 0.010272214821693637,
      0.0024692527383167804, 0.004536446048889638, 0.012244081222251282,
      0.010526712698432203, 0.033398738011474687, 0.02813455387980294,

      0.010066005831263696, 0.0019506300334422779, -0.0019043398614618512,
      -0.004088008352356501, 0.024485482387752576, 0.00699954945689254,
      -0.01904707399526851, 0.0033876388201005622, 0.017088928573562367,

      0.017872693491787887, 0.041176979595515396, 0.022870978612379685,
      0.017734548112707957, 0.055116178410244576, 0.02362922566039578,
      -0.0032533670131090326, 0.014820133977046292, 0.015189531245204876,

      0.013417026107347331, 0.008199462128197013, 0.008682901874796495,
      -0.001057281436202452, -9.580529058057278E-4, 0.018130305232989127,
      0.013197627737917545, 0.025993499815504965, 0.0282101706655173,

      0.004722146118847874, -6.211365954561445E-4, 0.001463247132531586,
      -0.0040490572380433, 0.009980156570398933, 0.003884023741930937,
      -0.008444307994339605, -0.003627754988835289, 0.014385061921127408,

      0.007966076755923803, 0.030602239509086984, 0.016312201377609553,
      0.015614085801242397, 0.03735781819802621, 0.011600928393055569,
      -0.00961434081421862, 0.010773314255620866, 0.007550266521377238,

      0.006972407988187046, 0.005173251794764602, 0.004780105131762285,
      -5.812169827385085E-4, 6.467089566629608E-4, 0.011314282674270778,
      0.009924497929978446, 0.01018491048439032, 0.022783415058596372
    )

    val expectedGrad = Tensor[Double](Storage(expectedGradData), 1,
      Array(batchSize, seqLength, inputSize, 3, 3))

    // `map` alone only requires equal element counts, so check the full shape explicitly.
    gradInput.size() should be (expectedGrad.size())

    // Element-wise comparison; returning v1 leaves expectedGrad unchanged.
    expectedGrad.map(gradInput, (v1, v2) => {
      TestUtils.conditionFailTest(abs(v1 - v2) < 1e-6)
      v1
    })
  }

  "A InternalConvLSTM2D " should "generate correct output when batch != 1" in {
    // Checks the input gradient of a recurrent InternalConvLSTM2D with batchSize > 1 against
    // fixed reference values. The forward output itself is fed back as gradOutput, so the
    // result depends only on the input data and the fixed weights below.
    val hiddenSize = 4
    val inputSize = 2
    val seqLength = 2
    val batchSize = 3

    // Shaped below as (batchSize, seqLength, inputSize, 3, 4).
    val inputData = Array(
      0.7379243471309989, 0.19769034340444613, 0.7553588318460729, 0.04613053826696778,
      0.5252748330906923, 0.3243104024273151, 0.8989024506441619, 0.6712812402234619,
      0.15323104849092073, 0.10185090916201034, 0.5226652689296567, 0.2899686031416233,

      0.511326269557182, 0.49001743514565177, 0.1984377134313693, 0.7480165633318978,
      0.6211121942699456, 0.6036424737237795, 0.1592963148948352, 0.26366166406420644,
      0.7005873221417939, 0.35208052461842276, 0.725171953811403, 0.5044467614524997,

      0.523452964024349, 0.7937875951399342, 0.6943424385328354, 0.8522757484294406,
      0.2271578105064339, 0.5126336840147973, 0.6541267785298224, 0.3414006259681709,
      0.4863876223334871, 0.9055057744040688, 0.8798793872101478, 0.41524602527888876,

      0.8876018988857928, 0.2600152644808019, 0.8634521212917783, 0.2153581486704762,
      0.4154146887274315, 0.89663329916328, 0.5406127724020381, 0.5893982389701549,
      0.19306165705001244, 0.45133332543350857, 0.9714792090557134, 0.38498402005236265,

      0.006891668205946, 0.5828297372808069, 0.5688474525455758, 0.8823143009035355,
      0.6887655121242025, 0.0027582524956907273, 0.9663642226218756, 0.4108958429337164,
      0.13044029438378613, 0.7476518446247042, 0.27181284499242064, 0.2943235571195709,

      0.5604631859123652, 0.038323627725658116, 0.7603531593097606, 0.27194849168888424,
      0.7267936284527879, 0.9052766088556446, 0.02837551118315562, 0.4574005827539791,
      0.7973179561393601, 0.9980165989408478, 0.25069973544723445, 0.7113158573554784,

      0.11561652478927409, 0.6416543947530068, 0.4621700648670196, 0.029251147823290524,
      0.43078782880788635, 0.25593904336252427, 0.7192673207531932, 0.1756512226290753,
      0.8578302327025378, 0.7098735179715736, 0.5881799635556655, 0.3308673021852163,

      0.6190265867439853, 0.04932089904003756, 0.6528462133452869, 0.9606073952644223,
      0.5669727867139613, 0.30551745863715773, 0.6397990914737653, 0.4658329782779089,
      0.5963362791762975, 0.7691666538081872, 0.3162174933237236, 0.4259981344604039,

      0.3141262752374987, 0.1977960342228966, 0.5334056420410246, 0.40966136495594996,
      0.757202109971997, 0.913316506797173, 0.6599946668051349, 0.3938817528802213,
      0.5209422302691442, 0.6422317952685995, 0.17747939441294336, 0.8074452051580658,

      0.3944410584066397, 0.2527031847980983, 0.39527449481971033, 0.0928484214334575,
      0.4852964142206313, 0.25068274731805196, 0.08329386976314157, 0.4681014271865992,
      0.860995719082069, 0.01158601045989438, 0.31712803268106715, 0.2617924245499301,

      0.934104253955566, 0.24448283400001236, 0.9842141116573005, 0.7934911952785096,
      0.17490418077955516, 0.11254306305960093, 0.20896334129728766, 0.2764503640807836,
      0.24782820561381746, 0.7664813069795481, 0.4206320455881154, 0.6306966327857411,

      0.33697752771986067, 0.7870030511736112, 0.9097909928901393, 0.41198674582123507,
      0.25120817362250203, 0.17776223144093362, 0.49757005026260037, 0.28573572832782146,
      0.8807161778981778, 0.10209784975052816, 0.37642364632639536, 0.6049083416711989)
    val input = Tensor[Double](inputData, Array(batchSize, seqLength, inputSize, 3, 4))

    // Kernel size 3, stride 1, no peephole connections.
    // NOTE(review): padding = -1 presumably selects "same" padding — confirm in InternalConvLSTM2D.
    val rec = new InternalRecurrent[Double]()
    val model = Sequential[Double]()
      .add(rec
        .add(new InternalConvLSTM2D[Double](inputSize, hiddenSize, 3, 1,
          padding = -1, withPeephole = false)))

    // Flattened reference parameters, copied verbatim into the model's parameter vector.
    val weightData = Array(
      -0.0697708, 0.187022, 0.08511595, 0.096392, 0.004365, -0.181258, 0.0446674,
      -0.1335725, -0.20553963, -0.06138988, -0.07350091, 0.21952641, -0.20255956, -0.010300428,
      -0.038676325, 0.1958987, -0.13511689, -0.101483345, -0.125394, 0.21549562, 0.009512273,
      0.22363727, -0.12906058, -0.1364867, -0.096433975, 0.05142183, 0.0600333, 0.104275264,
      -0.23061629, -0.14527707, 0.011904705, -0.13122521, -0.028477365, -0.17441525, -0.22748822,
      0.089772895, 0.047597, -0.12802905, 0.023593059, 0.047630176, -0.0041123535, 0.15042211,
      0.22905788, 0.16112402, 0.17447555, -0.17807686, -0.1150537, -0.112298205, -0.09767061,
      -0.06547484, -0.059544224, -0.068848155, 0.017608726, -0.0486288, 0.040728, 0.0032547752,
      -0.062271275, -0.07636171, -0.090193786, -0.1757421, -0.029310212, -0.055856735, 0.2138684,
      0.08065, -0.14784653, -0.11130823, 0.15752073, 0.0417686, 0.07667526, -0.08721331,
      0.12652178, -0.22656773, 0.028447477, 0.059678376, 0.097432226, 0.07587893, 0.018427523,
      0.07969191, -0.16021226, 0.15130767, 0.16185403, 0.09212428, -0.009819705, -0.09506356,
      0.1286456, -0.08805564, 0.046685345, 0.12232006, -0.100213096, -0.15668556, 0.07114366,
      0.095224895, -0.124568574, 0.06927875, 0.0905962, -0.035104837, 0.15567762, -0.043663125,
      -0.07430526, -0.055192, 0.01406816, -0.082332, -0.09430171, -0.1095887, -0.07151577,
      -0.02634421, -0.06150073, 0.11257642, 0.14980435, -0.062605076, -0.12603457, -0.15501663,
      0.02251482, -0.08227526, 0.13384429, 0.024052791, 0.13748798, -0.08544985, -0.036989275,
      0.026403576, -0.136615, 0.05420539, 0.16269544, 0.035181798, 0.15763186, -0.1604577,
      0.1537655, 0.02734131, -0.048081715, 0.100141905, 0.028017784, 0.08105907, 0.019959895,
      0.13988869, 0.16058885, -0.003570326, 0.06591913, -0.10256911, 0.13575412, 0.04774964,
      -0.017293347, -0.048617315, 0.15695612, 0.15765312, -0.047396783, -0.16620952, 0.0025890507,
      -0.13422322, -0.03875675, -0.075357996, 0.113039605, 0.13345407, 0.09567941, -0.003772,
      0.07441882, 0.04021747, 0.12041045, -0.042105403, -0.027613033, 0.15320867, -0.12912026,
      -0.081750855, -0.0344171, -0.15512398, 0.15219747, 0.036528654, -0.012755581, 0.098534,
      -0.07061299, 0.02883929, 0.14481406, -0.051582605, 0.10316327, 0.085615724, 0.06536975,
      -0.054357443, 0.02749899, -0.013213737, 0.057099275, 0.15802467, -0.05081968, -0.07198317,
      0.11493357, -0.0012803806, 0.11840431, 0.10919253, -0.10307259, 0.087982856, 0.06715956,
      0.03439658, -0.1251883, -0.16122015, 0.11468333, 0.15124878, -0.040252376, 0.13959402,
      -0.018218568, -0.03417238, -0.07071258, 0.15031807, -0.09312864, -0.014361585, 0.009083145,
      0.07651518, -0.030849354, 0.053097464, 0.02317304, -0.0126761, -0.10731614, 0.08843881,
      -0.058363467, -0.07192067, 0.13913071, -0.07697743, 0.15923063, -0.08419231, 0.017478677,
      -0.08418075, -0.057064693, 0.0024510117, -0.20928818, -0.22167638, -0.038345862, 0.03438525,
      0.11347725, -0.12304, 0.026768617, 0.045132592, 0.091074154, 0.13448715, 0.11804616,
      -0.22657603, -0.016182138, -0.1331919, 0.05141501, 0.015872177, -0.12630826, 0.21568011,
      -0.10292801, 0.13611461, -0.08374142, -0.22699684, 0.16571483, -0.098663375, -0.018467197,
      -0.15427141, -0.15015155, 0.10223335, -0.14016786, -0.10880828, -0.21908437, 0.14608948,
      0.07250339, 0.06662375, -0.18800929, 0.11404393, 0.13704747, 0.14116052, 0.10486333,
      -0.010664585, -0.11811825, -0.1724059, 0.15996984, -0.17623067, -0.055978876, -0.10195447,
      -0.17426933, 0.009317461, -0.23025058, -0.22655061, 0.1504219, -0.22794038, 0.19736658,
      -0.076756656, -0.16812365, -0.22800718, 0.17344427, -0.12931758, 0.10014104, -0.13550109,
      -0.23511806, -0.06651987, 0.19619618, 0.097913995, 0.14484589, -0.20718974, 0.20920905,
      0.18459858, -0.008639049, -0.22041525, 0.08012527, 0.14633249, -0.024920763, -0.18607515,
      0.07662353, -0.15454617, 0.067585476, 0.05524932, -0.15291865, 0.12737663, 0.05814569,
      -0.11415055, 0.11919394, -0.06319863, 0.1465731, 0.054857183, -0.15075083, 0.090675876,
      0.1525343, -0.0066932747, -0.048541967, 0.06132587, -0.079331905, 0.11314261, 0.14027406,
      -0.0266242, -0.016292417, 0.07795509, 0.020753743, -0.10986114, 0.10756251, 0.02036946,
      0.026220735, -0.11005689, 0.10311518, 0.07109452, -0.09970161, 0.068307705, 0.11119034,
      -0.06424175, -0.0012396448, -0.11550802, -0.06943571, -0.110153, -0.041444167, -0.12524629,
      0.15868594, 0.008897657, -0.10843479, 0.15759167, -0.09669543, -0.08299825, -0.0937801,
      -0.020804988, -0.08680972, 0.083160855, 0.029616985, -0.017982747, 0.0037287925, 0.097527005,
      0.09538205, -0.0932, -0.097054094, 0.10397664, 0.12322543, -0.06448696, -0.12847184,
      0.050058555, 0.09502069, 0.08681986, -0.14003497, 0.03627888, -0.075629145, -0.095788166,
      -0.08410784, 0.13308963, -0.007147816, -0.16363329, 0.12797672, 0.124641, -0.05630061,
      0.0064241965, 0.077181205, -0.1251426, 0.08616565, 0.1477562, -0.04511368, 0.029885028,
      0.057127535, -0.08563146, 0.13729702, -0.10859255, -0.102196366, 0.008430395, -0.0945447,
      0.10205625, -0.07343792, 0.16189432, -0.1300748, 0.08548705, 0.16390403, -0.02669807,
      -0.058629803, -0.05904906, -0.016605929, 0.14874554, 0.014934211, -0.09052281, 0.0579616,
      -0.041529182, 0.09614261, 0.15888576, -0.11366321, -0.102919176, -0.1167308, -0.011413716,
      0.07176415, -0.03216456, -0.063260436, 0.059609246, 0.16423965, 0.0052398313, 0.1286797,
      -0.0381152, -0.009582818, 0.004786132, -0.1019815, 0.043783717, 0.05244485, -0.06435464,
      -0.16259833, -0.100482024, 0.0587321, -0.052555863, 0.032503795, -0.1606384, -0.14574,
      -0.05185242, -0.08184071, 0.15766397, 0.09867271, 0.08309498, -0.15646282, 0.15911676,
      -0.008041214, -0.07785257, 0.06866316, -0.07157379, 0.13319956, -0.066218115, 0.0138255,
      -0.076073825, -0.14936924, 0.10676395, -0.4842985, 0.083836384, -0.21565008, -0.2306193,
      0.0017399695, -0.07744393, -0.2044993, -0.21714376, 0.0077707577, 0.14650588, 0.19233301,
      -0.04317506, -0.058397546, 0.15633541, -0.115028, -0.044130307, 0.063888475, 0.21123467,
      0.20039539, -0.045635425, 0.040344927, -0.12061099, 0.13238785, -0.11554383, 0.012527357,
      -0.04936022, -0.223834, 0.2067501, 0.035001267, -0.121593, 0.08469669, -0.15821323,
      0.013301196, 0.19869077, -0.18677086, -0.09790556, 0.18662173, -0.216591, 0.13041325,
      0.13628985, 0.042848308, -0.031125685, -0.22374651, -0.087204315, 0.05124186, -0.22576457,
      0.014185649, -0.0899537, -0.015126135, -0.10904176, -0.212513, -0.19453013, 0.0071554612,
      -0.07960433, -0.20750536, -0.22908148, 0.066988595, -0.11946863, -0.20373446, -0.03756,
      -0.15693687, 0.015695922, -0.19193731, 0.035843078, 0.07994549, 0.025597025, -0.10725631,
      -0.11276663, -0.1882937, 0.019561082, 0.0135140605, 0.041632164, 0.0010907603, -0.06264914,
      -0.016213655, 0.0937373, 0.094795, -0.1173104, -0.21944033, -0.09396857, 0.13556847,
      -0.09024931, 0.1276821, -2.7715965E-4, 0.12017726, 0.13998412, 0.13809435, 0.16587347,
      0.04789949, -0.08513931, 0.07294201, -0.08220003, -0.15560868, 0.14816408, -0.09582949,
      0.051776934, -0.011485172, 0.14832942, 0.10104054, 0.080303155, 0.0034141147, 0.14833276,
      0.09612207, 0.11273294, 0.13111332, -0.00879518, -0.1397018, -0.10093753, -0.00945932,
      -0.032682095, -0.14018348, 0.050238717, 0.09185889, -0.14419281, 0.09613244, -0.13719763,
      0.04358094, -0.15398286, -0.116741166, -0.11954482, 0.14914127, -0.126483, -0.026603939,
      0.15768388, 0.06356159, 0.05631903, 0.0101217795, 0.15248485, -0.14745563, -0.0145869935,
      0.0382958, 0.057202652, -0.14191794, 0.059604887, 0.011006361, -0.07016107, 0.076446384,
      0.013760659, -0.068240955, 0.0037634, 0.12695941, 0.041081227, 0.10223117, 0.11603621,
      -0.06294605, -0.010134418, -0.006934982, 0.11731349, -0.10002373, 0.14468494, 0.006046706,
      -0.11748926, -0.13269922, 0.08922616, 0.076726876, 0.079133116, -0.13795392, 0.05776867,
      -0.12632991, -0.16351144, -0.067499354, 0.047223303, 0.063164465, -0.0149828065, -0.031813424,
      -0.08393954, -0.067819, 0.081516, -0.1065244, 0.14492081, 0.11396905, -0.10664382,
      -0.0098184915, 0.08660889, -0.16464078, 0.07709077, -0.1493178, 0.017629929, 0.08108806,
      -0.057861995, 0.05144662, -0.019507658, 0.098744385, 0.14157839, -0.101155385, -0.1155548,
      -0.1539434, 0.07039324, -0.015811022, 0.15094946, -0.16115923, 0.116900794, 0.11721963,
      0.020760974, 0.0040808455, -0.0896887, 0.013347261, 0.11278092, -0.07966485, -0.094330534,
      -0.15664604, 0.015197758, 0.12119024, -0.05060158, 0.06654976, -0.13198644, -0.1457269,
      -0.13899888, 0.038908076, -0.13269305, -0.11445787, 0.021789772, 0.027084751, 0.01323522,
      -0.12667863, 0.026683968, 0.04916361, 0.0086855, 0.15367854, 0.031549584, -0.0036370864,
      0.08499007, -0.10802871, 0.03548985, -0.17660856, -0.068241306, 0.15097389, 0.16520916,
      0.024556529, 0.0017257226, 0.17331718, 0.196117, 0.19437543, 0.19648184, -0.1331118,
      -0.21632133, -0.18020143, 0.12856491, 0.1344524, 0.11382166, 0.064181186, 0.14279565,
      -0.08350899, -0.2256594, -0.13126723, 0.043258272, -0.021165192, 0.089386486, -0.09204444,
      -0.0960608, -0.037649803, 0.22336064, -0.031554904, 0.124656096, -0.025671339, -0.1065685,
      0.0453102, -1.68393E-5, 0.22479524, 0.046631828, 0.007860622, -0.22629729, -0.13721013,
      0.22810946, -0.12107487, 0.022246245, 0.17803338, 0.2083739, 0.18673882, -0.1917718,
      0.07565709, 0.120346785, -0.14759375, -0.1377154, 0.038963128, 0.22792713, -0.2159763,
      -0.006619736, 0.2313753, -0.04800687, -0.1518908, 0.18948461, 0.1076321, -0.11479616,
      -0.0212803, 0.14886868, -0.22150691, 0.089185275, -0.040394045, 0.13415302, 0.21480684,
      0.0878023, 0.106930904, -0.18570949, -0.013600573, 0.11532847, 0.11659276, 0.112827145,
      -0.1062416, 0.066263296, -0.08610482, 0.105527066, -0.058957383, -0.15528603, -0.009521967,
      0.011328606, -0.06197259, -0.13204348, 0.08675131, -0.113543, -0.01445269, 0.02258719,
      -0.008030752, -0.093486756, -0.07264881, 0.09213272, 0.07619277, 0.16032794, -0.026074272,
      0.066076815, -0.10525776, 0.16016503, 0.03144442, -0.023126643, 0.05451808, 0.022852356,
      -0.096872106, -0.030566314, -0.16589479, 0.0905115, -0.1473723, -0.12166525, 0.078377604,
      0.13821222, -0.078764655, 0.14731602, -0.08815969, -0.0236424, -0.0355236, -0.09844407,
      -0.012984, 0.047678906, -0.038449008, 0.08535368, 0.15068671, -0.008833185, -0.09007217,
      0.112541415, -0.06900989, -0.102243155, -0.050330114, -0.13928314, -0.041724514, 0.054797813,
      -0.16646549, 0.13796, 0.12394269, 0.020277899, -0.013631716, -0.09424963, -0.13880578,
      0.08686539, -0.15236098, 0.05722864, -0.02671615, -0.06085055, 0.09522983, -0.03990184,
      -0.06986189, -0.014213024, -0.1377847, 0.08251909, 0.0143873375, -0.0860864, -0.0640099,
      -0.06048214, -0.030843036, 0.10346391, -0.14285919, 0.1575129, 0.11078764, -0.09553229,
      -0.15557009, -0.039680153, -0.02489069, 0.03813003, 0.1080799, 0.07591443, 0.1631084,
      0.04714953, -0.10192201, -0.12497483, 0.038626827, -0.07361671, -0.097818114, -0.14928903,
      -0.14453772, 0.10313048, 0.11320499, -0.063832685, 0.011636197, -0.16415314, -0.142816,
      0.041214544, -0.119791135, 0.10883034, -0.14729027, -0.122481905, 0.08507194, -0.088145964,
      -0.015075706, 0.06492, -0.16094309, 0.12339206, 0.011586048, 0.1321518, -0.05177626,
      0.033773363, -0.13636817, 0.013378032, -0.003163873, 0.02471618, -0.13203168, 0.07989189,
      -0.054477777, 0.059936292, -0.077277765, 0.019922124, -0.15395634, 0.0088137, 0.036947053,
      -0.11207754, 0.042513624, -0.05665606, -0.015827265, 0.12174054
    )
    val weights = model.getParameters()._1
    // Fail with a readable message (instead of deep inside copy) if the model's flattened
    // parameter count no longer matches the reference data.
    weights.nElement() should be (weightData.length)
    weights.copy(Tensor[Double](weightData, Array(weightData.length, 1)))

    val output = model.forward(input)
    // The forward output doubles as gradOutput for the backward pass.
    val gradInput = model.backward(input, output).toTensor[Double]

    val expectedGradData = Array(
      0.034878514686072545, 0.04398729956621657, 0.05758099669158952, 0.03437097904494599,
      0.0249094742719872, 0.07606488761784662, 0.09516671025789802, 0.06654659858425119,
      0.01808987869570458, 0.048331800259105696, 0.06467829266973679, 0.05127588163224714,

      0.03994717516008958, 0.038453084506413944, 0.01098428754646074, 0.004754856971752791,
      0.04691100984590687, 0.04733273359693526, 0.032698733073028036, 0.008710296200014872,
      0.016446470040048802, 0.014504827825521021, 0.014398413896711182, 0.005563785069582132,

      0.02508919498362225, 0.04510472680505569, 0.05014850646512659, 0.03689880987099306,
      0.018396925338213875, 0.06508697182696906, 0.07680143870812711, 0.05405187114525272,
      0.011704545443457297, 0.04024544101139729, 0.05209939603631103, 0.04425816399870016,

      0.03849933792429499, 0.026104351589482305, 0.014195299345464618, 0.00429469388842192,
      0.0391442871880561, 0.03748689488743062, 0.034699130158092666, 0.014060432566210765,
      0.019155416496624153, 0.019787616624734043, 0.026792094505297428, 0.008815024575985419,

      0.011581512006221744, 0.06048972732734363, 0.03282915602659322, 0.04561238830468398,
      0.041571255786098195, 0.07116665388420179, 0.08661446915699353, 0.059578562302912356,
      0.020123218760508918, 0.05257043899642402, 0.06063252611339279, 0.04371064348072411,

      0.045257409698179596, 0.03071963993748246, 0.02529315514566803, -0.005670740927325354,
      0.03639680434337409, 0.048487407146512125, 0.02275229820249338, 0.014118614303447524,
      0.02428551150184052, 0.016819297537995985, 0.004203314658387305, 0.013208252415897827,

      0.00941868766149703, 0.054293608957875564, 0.04095605099670765, 0.03558088296182436,
      0.03155322059071331, 0.06819978262289549, 0.0698245353478215, 0.049097324020888955,
      0.024622612693984566, 0.04302580855110488, 0.04833338328255696, 0.03833458594732048,

      0.03859568415306832, 0.019662097401664583, 0.020328833083506203, 0.013087242556229211,
      0.03063108223918263, 0.038899278583615186, 0.029803127563184754, 0.012881705196662431,
      0.022743671177388635, 0.024842756232781472, 0.017745566495452548, 0.010775736917320205,

      0.03143314853307875, 0.026008021124113205, 0.05003786254172548, 0.049226925346696454,
      0.04179540300548766, 0.07442518781716083, 0.06586987783245223, 0.06962442807281881,
      0.019949908429511852, 0.06736192219429586, 0.036991517213459404, 0.06295162984318212,

      0.031871333755458135, 0.029843372655222782, 0.017421399722145476, -0.0023651889661522544,
      0.041743057643546515, 0.021105312244889884, 0.0279956737169246, 0.007435954259695617,
      0.02861718110766123, 0.005404362682476721, 0.023503489726929276, -0.0010247413809265477,

      0.04088956639631173, 0.013630851891249091, 0.05529042930307609, 0.049642832403158754,
      0.0351046539048856, 0.050125216770852105, 0.05088086050445006, 0.06647703410350292,
      0.018242402674039926, 0.04993188252833041, 0.03556726962341712, 0.051466671646896084,

      0.014878868479939288, 0.0391711844039187, 0.03236231195889219, 0.003706375401465365,
      0.022623344961740982, 0.021894028156073286, 0.03376341058584762, 0.014814947908133192,
      0.03352917374224103, 0.009108881997277354, 0.022206440887808925, 0.005025803338779346
    )
    val expectedGrad = Tensor[Double](Storage(expectedGradData), 1,
      Array(batchSize, seqLength, inputSize, 3, 4))

    // The gradient w.r.t. the input must have the input's shape before comparing elementwise.
    gradInput.size() should be (expectedGrad.size())
    // map is used only to visit (expected, actual) element pairs within a 1e-6 tolerance.
    // Returning v1 is intended to leave expectedGrad's values as they were
    // (assumes map writes the lambda result back into expectedGrad).
    expectedGrad.map(gradInput, (v1, v2) => {
      TestUtils.conditionFailTest(abs(v1 - v2) < 1e-6)
      v1
    })
  }

  // Tested with torch convlstm
  "A InternalConvLSTM " should "return expected hidden and cell state when batch != 1" in {
    val hiddenSize = 4
    val inputSize = 2
    val seqLength = 2
    val batchSize = 3

    val inputData = Array(
      0.7379243471309989, 0.19769034340444613, 0.7553588318460729, 0.04613053826696778,
      0.5252748330906923, 0.3243104024273151, 0.8989024506441619, 0.6712812402234619,
      0.15323104849092073, 0.10185090916201034, 0.5226652689296567, 0.2899686031416233,

      0.511326269557182, 0.49001743514565177, 0.1984377134313693, 0.7480165633318978,
      0.6211121942699456, 0.6036424737237795, 0.1592963148948352, 0.26366166406420644,
      0.7005873221417939, 0.35208052461842276, 0.725171953811403, 0.5044467614524997,

      0.523452964024349, 0.7937875951399342, 0.6943424385328354, 0.8522757484294406,
      0.2271578105064339, 0.5126336840147973, 0.6541267785298224, 0.3414006259681709,
      0.4863876223334871, 0.9055057744040688, 0.8798793872101478, 0.41524602527888876,

      0.8876018988857928, 0.2600152644808019, 0.8634521212917783, 0.2153581486704762,
      0.4154146887274315, 0.89663329916328, 0.5406127724020381, 0.5893982389701549,
      0.19306165705001244, 0.45133332543350857, 0.9714792090557134, 0.38498402005236265,

      0.006891668205946, 0.5828297372808069, 0.5688474525455758, 0.8823143009035355,
      0.6887655121242025, 0.0027582524956907273, 0.9663642226218756, 0.4108958429337164,
      0.13044029438378613, 0.7476518446247042, 0.27181284499242064, 0.2943235571195709,

      0.5604631859123652, 0.038323627725658116, 0.7603531593097606, 0.27194849168888424,
      0.7267936284527879, 0.9052766088556446, 0.02837551118315562, 0.4574005827539791,
      0.7973179561393601, 0.9980165989408478, 0.25069973544723445, 0.7113158573554784,

      0.11561652478927409, 0.6416543947530068, 0.4621700648670196, 0.029251147823290524,
      0.43078782880788635, 0.25593904336252427, 0.7192673207531932, 0.1756512226290753,
      0.8578302327025378, 0.7098735179715736, 0.5881799635556655, 0.3308673021852163,

      0.6190265867439853, 0.04932089904003756, 0.6528462133452869, 0.9606073952644223,
      0.5669727867139613, 0.30551745863715773, 0.6397990914737653, 0.4658329782779089,
      0.5963362791762975, 0.7691666538081872, 0.3162174933237236, 0.4259981344604039,

      0.3141262752374987, 0.1977960342228966, 0.5334056420410246, 0.40966136495594996,
      0.757202109971997, 0.913316506797173, 0.6599946668051349, 0.3938817528802213,
      0.5209422302691442, 0.6422317952685995, 0.17747939441294336, 0.8074452051580658,

      0.3944410584066397, 0.2527031847980983, 0.39527449481971033, 0.0928484214334575,
      0.4852964142206313, 0.25068274731805196, 0.08329386976314157, 0.4681014271865992,
      0.860995719082069, 0.01158601045989438, 0.31712803268106715, 0.2617924245499301,

      0.934104253955566, 0.24448283400001236, 0.9842141116573005, 0.7934911952785096,
      0.17490418077955516, 0.11254306305960093, 0.20896334129728766, 0.2764503640807836,
      0.24782820561381746, 0.7664813069795481, 0.4206320455881154, 0.6306966327857411,

      0.33697752771986067, 0.7870030511736112, 0.9097909928901393, 0.41198674582123507,
      0.25120817362250203, 0.17776223144093362, 0.49757005026260037, 0.28573572832782146,
      0.8807161778981778, 0.10209784975052816, 0.37642364632639536, 0.6049083416711989)
    val input = Tensor[Double](inputData, Array(batchSize, seqLength, inputSize, 3, 4))

    val rec = new InternalRecurrent[Double]()
    val model = Sequential[Double]()
      .add(rec
        .add(new InternalConvLSTM2D[Double](inputSize, hiddenSize, 3, 1,
          padding = -1, withPeephole = false)))

    val weightData = Array(
      -0.0697708, 0.187022, 0.08511595, 0.096392, 0.004365, -0.181258, 0.0446674,
      -0.1335725, -0.20553963, -0.06138988, -0.07350091, 0.21952641, -0.20255956, -0.010300428,
      -0.038676325, 0.1958987, -0.13511689, -0.101483345, -0.125394, 0.21549562, 0.009512273,
      0.22363727, -0.12906058, -0.1364867, -0.096433975, 0.05142183, 0.0600333, 0.104275264,
      -0.23061629, -0.14527707, 0.011904705, -0.13122521, -0.028477365, -0.17441525, -0.22748822,
      0.089772895, 0.047597, -0.12802905, 0.023593059, 0.047630176, -0.0041123535, 0.15042211,
      0.22905788, 0.16112402, 0.17447555, -0.17807686, -0.1150537, -0.112298205, -0.09767061,
      -0.06547484, -0.059544224, -0.068848155, 0.017608726, -0.0486288, 0.040728, 0.0032547752,
      -0.062271275, -0.07636171, -0.090193786, -0.1757421, -0.029310212, -0.055856735, 0.2138684,
      0.08065, -0.14784653, -0.11130823, 0.15752073, 0.0417686, 0.07667526, -0.08721331,
      0.12652178, -0.22656773, 0.028447477, 0.059678376, 0.097432226, 0.07587893, 0.018427523,
      0.07969191, -0.16021226, 0.15130767, 0.16185403, 0.09212428, -0.009819705, -0.09506356,
      0.1286456, -0.08805564, 0.046685345, 0.12232006, -0.100213096, -0.15668556, 0.07114366,
      0.095224895, -0.124568574, 0.06927875, 0.0905962, -0.035104837, 0.15567762, -0.043663125,
      -0.07430526, -0.055192, 0.01406816, -0.082332, -0.09430171, -0.1095887, -0.07151577,
      -0.02634421, -0.06150073, 0.11257642, 0.14980435, -0.062605076, -0.12603457, -0.15501663,
      0.02251482, -0.08227526, 0.13384429, 0.024052791, 0.13748798, -0.08544985, -0.036989275,
      0.026403576, -0.136615, 0.05420539, 0.16269544, 0.035181798, 0.15763186, -0.1604577,
      0.1537655, 0.02734131, -0.048081715, 0.100141905, 0.028017784, 0.08105907, 0.019959895,
      0.13988869, 0.16058885, -0.003570326, 0.06591913, -0.10256911, 0.13575412, 0.04774964,
      -0.017293347, -0.048617315, 0.15695612, 0.15765312, -0.047396783, -0.16620952, 0.0025890507,
      -0.13422322, -0.03875675, -0.075357996, 0.113039605, 0.13345407, 0.09567941, -0.003772,
      0.07441882, 0.04021747, 0.12041045, -0.042105403, -0.027613033, 0.15320867, -0.12912026,
      -0.081750855, -0.0344171, -0.15512398, 0.15219747, 0.036528654, -0.012755581, 0.098534,
      -0.07061299, 0.02883929, 0.14481406, -0.051582605, 0.10316327, 0.085615724, 0.06536975,
      -0.054357443, 0.02749899, -0.013213737, 0.057099275, 0.15802467, -0.05081968, -0.07198317,
      0.11493357, -0.0012803806, 0.11840431, 0.10919253, -0.10307259, 0.087982856, 0.06715956,
      0.03439658, -0.1251883, -0.16122015, 0.11468333, 0.15124878, -0.040252376, 0.13959402,
      -0.018218568, -0.03417238, -0.07071258, 0.15031807, -0.09312864, -0.014361585, 0.009083145,
      0.07651518, -0.030849354, 0.053097464, 0.02317304, -0.0126761, -0.10731614, 0.08843881,
      -0.058363467, -0.07192067, 0.13913071, -0.07697743, 0.15923063, -0.08419231, 0.017478677,
      -0.08418075, -0.057064693, 0.0024510117, -0.20928818, -0.22167638, -0.038345862, 0.03438525,
      0.11347725, -0.12304, 0.026768617, 0.045132592, 0.091074154, 0.13448715, 0.11804616,
      -0.22657603, -0.016182138, -0.1331919, 0.05141501, 0.015872177, -0.12630826, 0.21568011,
      -0.10292801, 0.13611461, -0.08374142, -0.22699684, 0.16571483, -0.098663375, -0.018467197,
      -0.15427141, -0.15015155, 0.10223335, -0.14016786, -0.10880828, -0.21908437, 0.14608948,
      0.07250339, 0.06662375, -0.18800929, 0.11404393, 0.13704747, 0.14116052, 0.10486333,
      -0.010664585, -0.11811825, -0.1724059, 0.15996984, -0.17623067, -0.055978876, -0.10195447,
      -0.17426933, 0.009317461, -0.23025058, -0.22655061, 0.1504219, -0.22794038, 0.19736658,
      -0.076756656, -0.16812365, -0.22800718, 0.17344427, -0.12931758, 0.10014104, -0.13550109,
      -0.23511806, -0.06651987, 0.19619618, 0.097913995, 0.14484589, -0.20718974, 0.20920905,
      0.18459858, -0.008639049, -0.22041525, 0.08012527, 0.14633249, -0.024920763, -0.18607515,
      0.07662353, -0.15454617, 0.067585476, 0.05524932, -0.15291865, 0.12737663, 0.05814569,
      -0.11415055, 0.11919394, -0.06319863, 0.1465731, 0.054857183, -0.15075083, 0.090675876,
      0.1525343, -0.0066932747, -0.048541967, 0.06132587, -0.079331905, 0.11314261, 0.14027406,
      -0.0266242, -0.016292417, 0.07795509, 0.020753743, -0.10986114, 0.10756251, 0.02036946,
      0.026220735, -0.11005689, 0.10311518, 0.07109452, -0.09970161, 0.068307705, 0.11119034,
      -0.06424175, -0.0012396448, -0.11550802, -0.06943571, -0.110153, -0.041444167, -0.12524629,
      0.15868594, 0.008897657, -0.10843479, 0.15759167, -0.09669543, -0.08299825, -0.0937801,
      -0.020804988, -0.08680972, 0.083160855, 0.029616985, -0.017982747, 0.0037287925, 0.097527005,
      0.09538205, -0.0932, -0.097054094, 0.10397664, 0.12322543, -0.06448696, -0.12847184,
      0.050058555, 0.09502069, 0.08681986, -0.14003497, 0.03627888, -0.075629145, -0.095788166,
      -0.08410784, 0.13308963, -0.007147816, -0.16363329, 0.12797672, 0.124641, -0.05630061,
      0.0064241965, 0.077181205, -0.1251426, 0.08616565, 0.1477562, -0.04511368, 0.029885028,
      0.057127535, -0.08563146, 0.13729702, -0.10859255, -0.102196366, 0.008430395, -0.0945447,
      0.10205625, -0.07343792, 0.16189432, -0.1300748, 0.08548705, 0.16390403, -0.02669807,
      -0.058629803, -0.05904906, -0.016605929, 0.14874554, 0.014934211, -0.09052281, 0.0579616,
      -0.041529182, 0.09614261, 0.15888576, -0.11366321, -0.102919176, -0.1167308, -0.011413716,
      0.07176415, -0.03216456, -0.063260436, 0.059609246, 0.16423965, 0.0052398313, 0.1286797,
      -0.0381152, -0.009582818, 0.004786132, -0.1019815, 0.043783717, 0.05244485, -0.06435464,
      -0.16259833, -0.100482024, 0.0587321, -0.052555863, 0.032503795, -0.1606384, -0.14574,
      -0.05185242, -0.08184071, 0.15766397, 0.09867271, 0.08309498, -0.15646282, 0.15911676,
      -0.008041214, -0.07785257, 0.06866316, -0.07157379, 0.13319956, -0.066218115, 0.0138255,
      -0.076073825, -0.14936924, 0.10676395, -0.4842985, 0.083836384, -0.21565008, -0.2306193,
      0.0017399695, -0.07744393, -0.2044993, -0.21714376, 0.0077707577, 0.14650588, 0.19233301,
      -0.04317506, -0.058397546, 0.15633541, -0.115028, -0.044130307, 0.063888475, 0.21123467,
      0.20039539, -0.045635425, 0.040344927, -0.12061099, 0.13238785, -0.11554383, 0.012527357,
      -0.04936022, -0.223834, 0.2067501, 0.035001267, -0.121593, 0.08469669, -0.15821323,
      0.013301196, 0.19869077, -0.18677086, -0.09790556, 0.18662173, -0.216591, 0.13041325,
      0.13628985, 0.042848308, -0.031125685, -0.22374651, -0.087204315, 0.05124186, -0.22576457,
      0.014185649, -0.0899537, -0.015126135, -0.10904176, -0.212513, -0.19453013, 0.0071554612,
      -0.07960433, -0.20750536, -0.22908148, 0.066988595, -0.11946863, -0.20373446, -0.03756,
      -0.15693687, 0.015695922, -0.19193731, 0.035843078, 0.07994549, 0.025597025, -0.10725631,
      -0.11276663, -0.1882937, 0.019561082, 0.0135140605, 0.041632164, 0.0010907603, -0.06264914,
      -0.016213655, 0.0937373, 0.094795, -0.1173104, -0.21944033, -0.09396857, 0.13556847,
      -0.09024931, 0.1276821, -2.7715965E-4, 0.12017726, 0.13998412, 0.13809435, 0.16587347,
      0.04789949, -0.08513931, 0.07294201, -0.08220003, -0.15560868, 0.14816408, -0.09582949,
      0.051776934, -0.011485172, 0.14832942, 0.10104054, 0.080303155, 0.0034141147, 0.14833276,
      0.09612207, 0.11273294, 0.13111332, -0.00879518, -0.1397018, -0.10093753, -0.00945932,
      -0.032682095, -0.14018348, 0.050238717, 0.09185889, -0.14419281, 0.09613244, -0.13719763,
      0.04358094, -0.15398286, -0.116741166, -0.11954482, 0.14914127, -0.126483, -0.026603939,
      0.15768388, 0.06356159, 0.05631903, 0.0101217795, 0.15248485, -0.14745563, -0.0145869935,
      0.0382958, 0.057202652, -0.14191794, 0.059604887, 0.011006361, -0.07016107, 0.076446384,
      0.013760659, -0.068240955, 0.0037634, 0.12695941, 0.041081227, 0.10223117, 0.11603621,
      -0.06294605, -0.010134418, -0.006934982, 0.11731349, -0.10002373, 0.14468494, 0.006046706,
      -0.11748926, -0.13269922, 0.08922616, 0.076726876, 0.079133116, -0.13795392, 0.05776867,
      -0.12632991, -0.16351144, -0.067499354, 0.047223303, 0.063164465, -0.0149828065, -0.031813424,
      -0.08393954, -0.067819, 0.081516, -0.1065244, 0.14492081, 0.11396905, -0.10664382,
      -0.0098184915, 0.08660889, -0.16464078, 0.07709077, -0.1493178, 0.017629929, 0.08108806,
      -0.057861995, 0.05144662, -0.019507658, 0.098744385, 0.14157839, -0.101155385, -0.1155548,
      -0.1539434, 0.07039324, -0.015811022, 0.15094946, -0.16115923, 0.116900794, 0.11721963,
      0.020760974, 0.0040808455, -0.0896887, 0.013347261, 0.11278092, -0.07966485, -0.094330534,
      -0.15664604, 0.015197758, 0.12119024, -0.05060158, 0.06654976, -0.13198644, -0.1457269,
      -0.13899888, 0.038908076, -0.13269305, -0.11445787, 0.021789772, 0.027084751, 0.01323522,
      -0.12667863, 0.026683968, 0.04916361, 0.0086855, 0.15367854, 0.031549584, -0.0036370864,
      0.08499007, -0.10802871, 0.03548985, -0.17660856, -0.068241306, 0.15097389, 0.16520916,
      0.024556529, 0.0017257226, 0.17331718, 0.196117, 0.19437543, 0.19648184, -0.1331118,
      -0.21632133, -0.18020143, 0.12856491, 0.1344524, 0.11382166, 0.064181186, 0.14279565,
      -0.08350899, -0.2256594, -0.13126723, 0.043258272, -0.021165192, 0.089386486, -0.09204444,
      -0.0960608, -0.037649803, 0.22336064, -0.031554904, 0.124656096, -0.025671339, -0.1065685,
      0.0453102, -1.68393E-5, 0.22479524, 0.046631828, 0.007860622, -0.22629729, -0.13721013,
      0.22810946, -0.12107487, 0.022246245, 0.17803338, 0.2083739, 0.18673882, -0.1917718,
      0.07565709, 0.120346785, -0.14759375, -0.1377154, 0.038963128, 0.22792713, -0.2159763,
      -0.006619736, 0.2313753, -0.04800687, -0.1518908, 0.18948461, 0.1076321, -0.11479616,
      -0.0212803, 0.14886868, -0.22150691, 0.089185275, -0.040394045, 0.13415302, 0.21480684,
      0.0878023, 0.106930904, -0.18570949, -0.013600573, 0.11532847, 0.11659276, 0.112827145,
      -0.1062416, 0.066263296, -0.08610482, 0.105527066, -0.058957383, -0.15528603, -0.009521967,
      0.011328606, -0.06197259, -0.13204348, 0.08675131, -0.113543, -0.01445269, 0.02258719,
      -0.008030752, -0.093486756, -0.07264881, 0.09213272, 0.07619277, 0.16032794, -0.026074272,
      0.066076815, -0.10525776, 0.16016503, 0.03144442, -0.023126643, 0.05451808, 0.022852356,
      -0.096872106, -0.030566314, -0.16589479, 0.0905115, -0.1473723, -0.12166525, 0.078377604,
      0.13821222, -0.078764655, 0.14731602, -0.08815969, -0.0236424, -0.0355236, -0.09844407,
      -0.012984, 0.047678906, -0.038449008, 0.08535368, 0.15068671, -0.008833185, -0.09007217,
      0.112541415, -0.06900989, -0.102243155, -0.050330114, -0.13928314, -0.041724514, 0.054797813,
      -0.16646549, 0.13796, 0.12394269, 0.020277899, -0.013631716, -0.09424963, -0.13880578,
      0.08686539, -0.15236098, 0.05722864, -0.02671615, -0.06085055, 0.09522983, -0.03990184,
      -0.06986189, -0.014213024, -0.1377847, 0.08251909, 0.0143873375, -0.0860864, -0.0640099,
      -0.06048214, -0.030843036, 0.10346391, -0.14285919, 0.1575129, 0.11078764, -0.09553229,
      -0.15557009, -0.039680153, -0.02489069, 0.03813003, 0.1080799, 0.07591443, 0.1631084,
      0.04714953, -0.10192201, -0.12497483, 0.038626827, -0.07361671, -0.097818114, -0.14928903,
      -0.14453772, 0.10313048, 0.11320499, -0.063832685, 0.011636197, -0.16415314, -0.142816,
      0.041214544, -0.119791135, 0.10883034, -0.14729027, -0.122481905, 0.08507194, -0.088145964,
      -0.015075706, 0.06492, -0.16094309, 0.12339206, 0.011586048, 0.1321518, -0.05177626,
      0.033773363, -0.13636817, 0.013378032, -0.003163873, 0.02471618, -0.13203168, 0.07989189,
      -0.054477777, 0.059936292, -0.077277765, 0.019922124, -0.15395634, 0.0088137, 0.036947053,
      -0.11207754, 0.042513624, -0.05665606, -0.015827265, 0.12174054
    )
    val weights = model.getParameters()._1
    weights.copy(Tensor[Double](weightData, Array(weightData.size, 1)))

    //        val weightsOri = new ArrayBuffer[Tensor[Double]]()
    //        val weightsNew = new ArrayBuffer[Tensor[Double]]()
    //
    //        val sizeI = hiddenSize * inputSize * 3 * 3
    //        val sizeH = hiddenSize * hiddenSize * 3 * 3
    //        var next = 0
    //        for(i <- 0 until 4) {
    //          val i2g = Tensor[Double](weightData.slice(next, next + sizeI),
    //            Array(1, hiddenSize, inputSize, 3, 3))
    //          weightsOri += Tensor[Double]().resizeAs(i2g).copy(i2g)
    //          next += sizeI
    //          val i2gBias = Tensor[Double](weightData.slice(next, next + hiddenSize),
    //            Array(1, hiddenSize))
    //          weightsOri += Tensor[Double]().resizeAs(i2gBias).copy(i2gBias)
    //          next += hiddenSize
    //          val h2g = Tensor[Double](weightData.slice(next, next + sizeH),
    //            Array(1, hiddenSize, hiddenSize, 3, 3))
    //          weightsOri += Tensor[Double]().resizeAs(h2g).copy(h2g)
    //          next += sizeH
    //        }
    //
    //        val weightsTable = T(weightsOri(0), weightsOri(3), weightsOri(6), weightsOri(9))
    //        val joinWeights = JoinTable[Double](2, 5)
    //        weightsNew += joinWeights.forward(weightsTable)
    //
    //        val biasTable = T(weightsOri(1), weightsOri(4), weightsOri(7), weightsOri(10))
    //        val joinBias = JoinTable[Double](1, 1)
    //        weightsNew += joinBias.forward(biasTable)
    //
    //        weightsNew += weightsOri(2)
    //        weightsNew += weightsOri(5)
    //        weightsNew += weightsOri(8)
    //        weightsNew += weightsOri(11)
    //
    //        weights.copy(Module.flatten[Double](weightsNew.toArray))
    val expectedCellData = Array(
      -0.14987348457222294, -0.22105992474941866, -0.3387238605194439, -0.39446403835780186,
      -0.316453669906271, -0.37140873935818197, -0.4794204029633373, -0.43318998938865677,
      -0.31715138375541807, -0.36546697609966056, -0.36668646927670884, -0.2531378632113071,

      0.12164367622452013, 0.06375662877054206, 0.10558023052572624, 0.09296068067631055,
      0.10350287214650072, -0.043105084379426555, -0.002532373132315092, -0.017107651118940928,
      0.0783054058646045, 0.09983730266019218, 0.15692887903321268, 0.0010369583032553142,


      -0.11568265586055029, -0.05157072657958458, -0.030545098446643958, 0.07097465680094239,
      -0.19039540889885484, -0.19104510993259888, -0.02513773533630481, 0.033952703953675126,
      -0.22412212420921007, -0.24804429480520482, -0.12672816312820606, -0.022755560537669244,


      -0.30986135292585515, -0.4251493103752495, -0.2889602199722825, -0.17273356360945663,
      -0.3580787788172773, -0.5158149093281363, -0.5032594569503482, -0.27377998399438014,
      -0.14520932544812054, -0.3487598323162634, -0.36089232520218323, -0.3875967229666661,


      -0.10052985037535103, -0.30405218242711163, -0.18654948355613826, -0.27308495250404397,
      -0.3780858084639802, -0.348063759578918, -0.5343725340433997, -0.3798054993937692,
      -0.2756440427688284, -0.38768793305018395, -0.3235976734640538, -0.24045597962433007,


      0.0341441391957715, 5.299130157097875E-4, 0.15185104088503992, 0.08317049353052722,
      0.11869821799103805, -0.015495938546971158, -0.03628157899463032, 0.03548216233650562,
      0.16615806661829663, 0.11809127970778976, 0.022987783943591498, 0.053172762021747644,


      -0.1608682545104523, 0.05471790389217375, -0.05140250835841573, 0.012483754945225009,
      -0.19132561788613034, -0.11758276843982571, -0.04001330599272664, -0.06993179715262401,
      -0.18292331023592934, -0.23018407727971535, -0.04965212591333139, -0.06706774082467053,


      -0.20722497302489531, -0.3163177647903549, -0.20047416971047505, -0.09940467310699157,
      -0.40391310591443685, -0.3453615998739744, -0.3436142871456762, -0.30579940792439053,
      -0.21393999767341623, -0.38632194331237324, -0.338132501410083, -0.2799806913472954,


      -0.2982647860718788, -3.718604397006525E-4, -0.2748632767327928, -0.3613061886167156,
      -0.35767327847568825, -0.4899646165682648, -0.4672077926890158, -0.37567367161152376,
      -0.1962868448170617, -0.3926117801022085, -0.2074747021487374, -0.3643094309050142,


      0.15985278516427298, 0.11449180058788376, 0.11367214355435092, 0.10212329736136479,
      0.11600188662627217, -0.08044840449625897, 0.042658116444042354, -0.06944713691444436,
      0.15348701605457232, -0.02271487773115937, 0.13052044218592151, 0.10538763744284842,


      0.06161551500417564, -0.1439853127761765, -0.03235769175109047, 0.05306321317925293,
      -0.023773273902641757, -0.18127460458925815, -0.08972730903797148, 0.10524359137085819,
      -0.22209026327385234, 0.015056602926191405, -0.19742347876502547, 0.01990844614402051,


      -0.18828070103927522, -0.21929051591262086, -0.3450388474490299, -0.21525044509990426,
      -0.36115984252761013, -0.30653256673883555, -0.4084202398086612, -0.3277215031561226,
      -0.1867237630660016, -0.3244769101591504, -0.23294822238292623, -0.26057123158689666
    )

    val output = model.forward(input).asInstanceOf[Tensor[Double]]
    val state = rec.getHiddenState()
    val hiddenState = state.toTable.apply(1).asInstanceOf[Tensor[Double]]
    val cell = state.toTable.apply(2).asInstanceOf[Tensor[Double]]
    hiddenState.map(output.select(2, seqLength), (v1, v2) => {
      TestUtils.conditionFailTest(abs(v1 - v2) == 0)
      v1
    })

    cell.map(Tensor[Double](expectedCellData, Array(batchSize, hiddenSize, 3, 4)),
      (v1, v2) => {
        TestUtils.conditionFailTest(abs(v1 - v2) < 1e-10)
        v1
      })

    rec.setHiddenState(state)
    model.forward(input)
  }

  "A InternalConvLSTM2D " should "with set state should generate different output" in {
    val hiddenSize = 4
    val inputSize = 2
    val seqLength = 2
    val batchSize = 3

    val input = Tensor[Double](batchSize, seqLength, inputSize, 3, 4).rand()

    // Both models are built right after seeding the RNG with the SAME seed, which is intended
    // to give them identical initial weights. The only intended difference is then the initial
    // hidden state set on `rec2` below. Reusing `seed` (rather than repeating the literal)
    // keeps the two in sync. If they diverged, the outputs would differ for the wrong reason.
    val seed = 890
    RNG.setSeed(seed)
    val model = Sequential[Double]()
      .add(new InternalRecurrent[Double]()
        .add(new InternalConvLSTM2D[Double](inputSize, hiddenSize, 3, 1,
          padding = -1, withPeephole = false)))

    RNG.setSeed(seed)
    val rec2 = new InternalRecurrent[Double]()
    val model2 = Sequential[Double]()
      .add(rec2
        .add(new InternalConvLSTM2D[Double](inputSize, hiddenSize, 3, 1,
          padding = -1, withPeephole = false)))

    val output = model.forward(input).toTensor[Double]

    // Give the second model a random (hidden, cell) initial state before its forward pass.
    rec2.setHiddenState(T(Tensor[Double](batchSize, hiddenSize, 3, 4).rand,
      Tensor[Double](batchSize, hiddenSize, 3, 4).rand))
    val output2 = model2.forward(input).toTensor[Double]

    // `map` is used only to visit element pairs. Returning v1 leaves `output` unchanged.
    // Every element is required to differ between the two runs.
    output.map(output2,
      (v1, v2) => {
        TestUtils.conditionFailTest(abs(v1 - v2) != 0)
        v1
      })
  }

  "InternalConvLSTM2D L2 regularizer" should "works correctly" in {
    import com.intel.analytics.bigdl.numeric.NumericDouble

    val hiddenSize = 5
    val inputSize = 3
    val seqLength = 4
    val batchSize = 1
    val kernelSize = 3
    val iterations = 100

    // model1 is penalised through SGD weight decay (0.1).
    // model2 has weight decay disabled and is instead penalised through
    // per-parameter L2Regularizer(0.1).
    // Training both from the same start should yield identical weights and losses.
    val state1 = T("learningRate" -> 0.1, "learningRateDecay" -> 5e-7,
      "weightDecay" -> 0.1, "momentum" -> 0.002)
    val state2 = T("learningRate" -> 0.1, "learningRateDecay" -> 5e-7,
      "weightDecay" -> 0.0, "momentum" -> 0.002)

    val criterion = new TimeDistributedCriterion[Double](new MSECriterion[Double])

    val input = Tensor[Double](batchSize, seqLength, inputSize, 3, 3).rand
    val labels = Tensor[Double](batchSize, seqLength, hiddenSize, 3, 3).rand

    val model1 = Sequential[Double]()
      .add(new InternalRecurrent[Double]()
        .add(new InternalConvLSTM2D[Double](
          inputSize,
          hiddenSize,
          kernelSize,
          1,
          padding = -1,
          withPeephole = true)))

    val (weights1, grad1) = model1.getParameters()

    val model2 = Sequential[Double]()
      .add(new InternalRecurrent[Double]()
        .add(new InternalConvLSTM2D[Double](
          inputSize,
          hiddenSize,
          kernelSize,
          1,
          padding = -1,
          wRegularizer = L2Regularizer(0.1),
          uRegularizer = L2Regularizer(0.1),
          bRegularizer = L2Regularizer(0.1),
          cRegularizer = L2Regularizer(0.1),
          withPeephole = true)))

    val (weights2, grad2) = model2.getParameters()
    // Start both models from identical parameters and gradients.
    weights2.copy(weights1.clone())
    grad2.copy(grad1.clone())

    val sgd = new SGD[Double]

    // Builds the closure SGD evaluates at each step. The closure does one forward/backward pass
    // of `model` over the fixed (input, labels) pair and returns (loss, grad). The tensor argument
    // is not read; the model uses its own parameters, as the original per-model closures did.
    def makeFeval(model: Sequential[Double], grad: Tensor[Double])
        : Tensor[Double] => (Double, Tensor[Double]) = { _ =>
      val output = model.forward(input).toTensor[Double]
      val loss = criterion.forward(output, labels)
      model.zeroGradParameters()
      val gradInput = criterion.backward(output, labels)
      model.backward(input, gradInput)
      (loss, grad)
    }

    // Runs `iterations` optimisation steps via `step` and logs each step's loss.
    // Returns the loss array reported by the final step.
    def train(step: () => Array[Double]): Array[Double] =
      (1 to iterations).map { i =>
        val loss = step()
        println(s"${i}-th loss = ${loss(0)}")
        loss
      }.last

    val feval1 = makeFeval(model1, grad1)
    val feval2 = makeFeval(model2, grad2)

    // model1 is fully trained before model2 starts, matching the original sequencing.
    val loss1 = train(() => sgd.optimize(feval1, weights1, state1)._2)
    val loss2 = train(() => sgd.optimize(feval2, weights2, state2)._2)

    weights1 should be(weights2)
    loss1 should be(loss2)
  }
}
